
package com.jstarcraft.ai.jsat.clustering.kmeans;

import static java.lang.Math.PI;
import static java.lang.Math.log;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.SimpleDataSet;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.clustering.SeedSelectionMethods;
import com.jstarcraft.ai.jsat.linear.MatrixStatistics;
import com.jstarcraft.ai.jsat.linear.Vec;

/**
 * This class provides a method of performing {@link KMeans} clustering when the
 * value of {@code K} is not known. It works by recursively splitting means up
 * to some specified maximum value. <br>
 * <br>
 * When the value of {@code K} is specified, the implementation will simply call
 * the regular KMeans object it was constructed with. <br>
 * <br>
 * Note, that specifying a minimum value of {@code K=1} has a tendency to not be
 * split by the algorithm, returning the naive result of 1 cluster. It is better
 * to use at least {@code K=2} as the default minimum, which is what the
 * implementation will start from when no range of {@code K} is given. <br>
 * <br>
 * See: Pelleg, D., &amp; Moore, A. (2000). <i>X-means: Extending K-means with
 * Efficient Estimation of the Number of Clusters</i>. In ICML (pp. 727–734).
 * San Francisco, CA, USA: Morgan Kaufmann Publishers Inc. Retrieved from
 * <a href=
 * "http://pdf.aminer.org/000/335/443/x_means_extending_k_means_with_efficient_estimation_of_the.pdf">
 * here</a>
 *
 * @author Edward Raff
 */
public class XMeans extends KMeans {
    private static final long serialVersionUID = -2577160317892141870L;
    /** When {@code true}, a cluster that fails a split test is never tested again. */
    private boolean stopAfterFail = false;
    /** When {@code true}, all centers are refined with k-means after every round of splitting. */
    private boolean iterativeRefine = true;

    /** Clusters with fewer points than this are never considered for splitting. */
    private int minClusterSize = 25;
    /** The k-means implementation every clustering run is delegated to. */
    private KMeans kmeans;

    /**
     * Creates a new X-Means object that delegates to a new {@link HamerlyKMeans}.
     */
    public XMeans() {
        this(new HamerlyKMeans());
    }

    /**
     * Creates a new X-Means object that delegates to the given k-means
     * implementation. Note that the given object is modified: it is configured to
     * save the distance to the nearest centroid and to store its means.
     * 
     * @param kmeans the k-means implementation to use for all clustering runs
     */
    public XMeans(KMeans kmeans) {
        super(kmeans.dm, kmeans.seedSelection, kmeans.rand);
        this.kmeans = kmeans;
        // the BIC computation below reads the per-point distance to the nearest centroid
        this.kmeans.saveCentroidDistance = true;
        this.kmeans.setStoreMeans(true);
    }

    /**
     * Copy constructor
     * 
     * @param toCopy the object to copy
     */
    public XMeans(XMeans toCopy) {
        super(toCopy);
        this.kmeans = toCopy.kmeans.clone();
        this.stopAfterFail = toCopy.stopAfterFail;
        this.iterativeRefine = toCopy.iterativeRefine;
        this.minClusterSize = toCopy.minClusterSize;
    }

    /**
     * Each new cluster will be tested for improvement according to the BIC metric.
     * If this is set to {@code true} then an optimization is done that once a
     * center fails to be improved by splitting, it will never be tested again. This
     * is a safe assumption when {@link #setIterativeRefine(boolean) } is set to
     * {@code false}, but otherwise may not quite be true. <br>
     * <br>
     * When {@code stopAfterFail} is {@code true} , X-Means will make at most O(k)
     * runs of k-means for the final value of k chosen. When {@code false} (the
     * default option), at most O(k<sup>2</sup>) runs of k-means will occur.
     * 
     * @param stopAfterFail {@code true} if a centroid shouldn't be re-tested once
     *                      it fails to split.
     */
    public void setStopAfterFail(boolean stopAfterFail) {
        this.stopAfterFail = stopAfterFail;
    }

    /**
     * 
     * @return {@code true} if clusters that fail to split wont be re-tested.
     *         {@code false} if they will.
     */
    public boolean isStopAfterFail() {
        return stopAfterFail;
    }

    /**
     * Sets the minimum size for splitting a cluster.
     * 
     * @param minClusterSize the minimum number of data points that must be present
     *                       in a cluster to consider splitting it
     * @throws IllegalArgumentException if {@code minClusterSize} is less than 2
     */
    public void setMinClusterSize(int minClusterSize) {
        if (minClusterSize < 2)
            throw new IllegalArgumentException("min cluster size that could be split is 2, not " + minClusterSize);
        this.minClusterSize = minClusterSize;
    }

    /**
     * 
     * @return the minimum number of data points that must be present in a cluster
     *         to consider splitting it
     */
    public int getMinClusterSize() {
        return minClusterSize;
    }

    /**
     * Sets whether or not the set of all cluster centers should be refined at every
     * iteration. By default this is {@code true} and part of how the X-Means
     * algorithm is described. Setting this to {@code false} can result in large
     * speedups at the potential cost of quality.
     * 
     * @param refineCenters {@code true} to refine the cluster centers at every
     *                      step, {@code false} to skip this step of the algorithm.
     */
    public void setIterativeRefine(boolean refineCenters) {
        this.iterativeRefine = refineCenters;
    }

    /**
     * 
     * @return {@code true} if the cluster centers are refined at every step,
     *         {@code false} if skipping this step of the algorithm.
     */
    public boolean getIterativeRefine() {
        return iterativeRefine;
    }

    /**
     * Clusters the data set searching for {@code K} in the range
     * {@code [2, max(N/20, 10)]}, where {@code N} is the data set size.
     */
    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        return cluster(dataSet, 2, Math.max(dataSet.size() / 20, 10), designations);
    }

    /**
     * Clusters the data set searching for {@code K} in the range
     * {@code [2, max(N/20, 10)]}, where {@code N} is the data set size.
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        return cluster(dataSet, 2, Math.max(dataSet.size() / 20, 10), parallel, designations);
    }

    /**
     * "p_j is simply the sum of K- 1 class probabilities, M * K centroid
     * coordinates, and one variance estimate."
     * 
     * @param K the number of clusters
     * @param D the number of dimensions
     * @return the number of free parameters
     */
    private static int freeParameters(int K, int D) {
        return (K - 1) + (D * K) + 1;
    }

    /**
     * Computes the local BIC of modeling a region with a single center.
     * 
     * @param n         the number of points in the region
     * @param D         the number of dimensions
     * @param sumSqDist the sum of squared distances of the points to the center
     * @return the BIC score of the one-center model
     */
    private static double bicOneCenter(int n, int D, double sumSqDist) {
        // do all arithmetic in double, n * D can exceed the int range
        final double dims = D;
        final double sigma = sumSqDist / (dims * (n - 1));
        return -n * dims / 2.0 * log(2 * PI * sigma) - dims / 2.0 * (n - 1)// that gets us the log like, last line to penalize for bic
                - freeParameters(1, D) / 2.0 * log(n);
    }

    /**
     * Computes the local BIC of modeling a region with two child centers.
     * 
     * @param n1        the number of points owned by the first child
     * @param n2        the number of points owned by the second child
     * @param D         the number of dimensions
     * @param sumSqDist the sum of squared distances of every point to the child
     *                  center that owns it
     * @return the BIC score of the two-center model
     */
    private static double bicTwoCenters(int n1, int n2, int D, double sumSqDist) {
        final int n = n1 + n2;
        // do all arithmetic in double, n * D can exceed the int range
        final double dims = D;
        final double sigma = sumSqDist / (dims * (n - 2));
        return n1 * log(n1) + n2 * log(n2) - n * log(n) - n * dims / 2.0 * log(2 * PI * sigma) - dims / 2.0 * (n - 2)// that gets us the log like, last line to penalize for bic
                - freeParameters(2, D) / 2.0 * log(n);
    }

    /**
     * Resets {@code localVar} and re-accumulates, for every cluster, the sum of
     * squared distances from each point to its nearest centroid as saved by the
     * most recent run of the delegate k-means.
     * 
     * @param localVar     the per-cluster accumulator to overwrite
     * @param designations the cluster assignment of each data point
     * @param count        the number of data points
     */
    private void recomputeLocalVariance(double[] localVar, int[] designations, int count) {
        Arrays.fill(localVar, 0.0);
        final double[] nearest = kmeans.nearestCentroidDist;
        for (int i = 0; i < count; i++)
            localVar[designations[i]] += nearest[i] * nearest[i];
    }

    /**
     * Clusters the data set, starting from {@code lowK} centers (a single center
     * if {@code lowK} is less than 2) and repeatedly splitting centers in two
     * while doing so improves the local BIC score, never exceeding {@code highK}
     * centers.
     * 
     * @param dataSet      the data set to cluster
     * @param lowK         the number of clusters to start from
     * @param highK        the maximum number of clusters to produce
     * @param parallel     passed on to the delegate k-means and distance metric
     * @param designations the array to store the assignments in. A new array is
     *                     allocated if this is {@code null} or too small
     * @return the array holding the cluster assignment of every data point
     * @throws IllegalArgumentException if {@code highK} is less than 1 or less
     *                                  than {@code lowK}
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        if (highK < 1 || highK < lowK)
            throw new IllegalArgumentException("highK must be at least 1 and no smaller than lowK, but lowK=" + lowK + " and highK=" + highK);

        final int N = dataSet.size();
        final int D = dataSet.getNumNumericalVars();// "M" in orig paper

        if (designations == null || designations.length < N)
            designations = new int[N];

        List<Vec> data = dataSet.getDataVectors();
        final List<Double> accelCache = dm.getAccelerationCache(data, parallel);

        /**
         * The sum of ||x - \mu_i||^2 for each cluster currently kept
         */
        double[] localVar = new double[highK];
        // initiate
        if (lowK >= 2) {
            means = new ArrayList<Vec>();
            kmeans.cluster(dataSet, accelCache, lowK, means, designations, true, parallel, true, null);
            recomputeLocalVariance(localVar, designations, data.size());
        } else// 1 mean of all the data
        {
            Arrays.fill(designations, 0);
            means = new ArrayList<>(Arrays.asList(MatrixStatistics.meanVector(dataSet)));
            List<Double> qi = dm.getQueryInfo(means.get(0));
            for (int i = 0; i < data.size(); i++)
                localVar[0] += Math.pow(dm.dist(i, means.get(0), qi, data, accelCache), 2);
        }

        // subS[i] is used below as the index, in the full data set, of the i'th point
        // of the cluster being tested. subC holds the 2-means assignment of that point
        int[] subS = new int[designations.length];
        int[] subC = new int[designations.length];

        // track if we should stop testing a mean or not
        List<Boolean> dontRedo = new ArrayList<>(Collections.nCopies(means.size(), false));

        int origMeans;
        do {
            origMeans = means.size();

            for (int c = 0; c < origMeans; c++) {
                if (means.size() >= highK)
                    break;// hit max K, nothing else may be split
                if (dontRedo.get(c))
                    continue;
                /*
                 * Next, in each parent region we run a local K-means (with K = 2) for each pair
                 * of children. It is local in that the children are fighting each other for the
                 * points in the parent's region: no others
                 */

                List<DataPoint> X = getDatapointsFromCluster(c, designations, dataSet, subS);
                final int n = X.size();// NOTE, not the same as N. PAY ATTENTION
                // TODO add the optimization in the paper where we check for movement, and dont
                // test means that haven't moved much
                if (n < minClusterSize)
                    continue;

                // call the protected method directly so that the distance to the nearest
                // center is saved. The empty mean list makes the delegate pick fresh seeds
                List<Vec> subMean = new ArrayList<>(2);
                kmeans.cluster(new SimpleDataSet(X), null, 2, subMean, subC, true, parallel, true, null);
                double[] nearDist = kmeans.nearestCentroidDist;
                Vec c1 = subMean.get(0);
                Vec c2 = subMean.get(1);

                /*
                 * "it determines which one to explore by improving the BIC locally in each
                 * region." so we only compute BIC from local information
                 */

                double var_c1 = 0;
                double var_c2 = 0;
                int size_c1 = 0;
                for (int i = 0; i < n; i++) {
                    final double sqDist = nearDist[i] * nearDist[i];
                    if (subC[i] == 0) {
                        var_c1 += sqDist;
                        size_c1++;
                    } else
                        var_c2 += sqDist;
                }
                int size_c2 = n - size_c1;

                // An empty child or n <= 2 leaves the two-center BIC undefined (0 * log(0)
                // or a zero divisor), so such a split is never accepted
                boolean improved = false;
                if (n > 2 && size_c1 > 0 && size_c2 > 0) {
                    // have needed values, now compute BIC for LOCAL models
                    double localNewBic = bicTwoCenters(size_c1, size_c2, D, var_c1 + var_c2);
                    double localOldBic = bicOneCenter(n, D, localVar[c]);
                    // strict comparison: a NaN score or a tie never triggers a split
                    improved = localNewBic > localOldBic;
                }

                if (!improved) {
                    if (stopAfterFail)// if we are going to trust that H0 is true forever, mark it
                        dontRedo.set(c, true);
                    continue;// passed the test, do not split
                }
                // else, accept the split

                // first, update assignment array. Cluster '0' stays as is, re-set cluster '1'
                final int newIndex = means.size();
                for (int i = 0; i < n; i++)
                    if (subC[i] == 1)
                        designations[subS[i]] = newIndex;
                // replace current mean and add new one
                means.set(c, c1.clone());// cur index in dontRedo stays false
                means.add(c2.clone());// add a 'false' for new center
                dontRedo.add(false);
                // keep the variance of both children current, otherwise the next round would
                // test them against stale values when no refinement is done in between
                localVar[c] = var_c1;
                localVar[newIndex] = var_c2;
            }
            // "Between each round of splitting, we run k-means on the entire dataset and
            // all the centers to refine the current solution"
            if (iterativeRefine && means.size() > 1) {
                kmeans.cluster(dataSet, accelCache, means.size(), means, designations, true, parallel, true, null);
                recomputeLocalVariance(localVar, designations, data.size());
            }
        } while (origMeans < means.size());

        if (!iterativeRefine)// if we haven't been refining we need to do so now!
            kmeans.cluster(dataSet, accelCache, means.size(), means, designations, false, parallel, false, null);
        return designations;
    }

    /**
     * @return the iteration limit of the delegate k-means
     */
    @Override
    public int getIterationLimit() {
        return kmeans.getIterationLimit();
    }

    /**
     * Sets the iteration limit of the delegate k-means.
     * 
     * @param iterLimit the new iteration limit
     */
    @Override
    public void setIterationLimit(int iterLimit) {
        kmeans.setIterationLimit(iterLimit);
    }

    /**
     * Sets the seed selection method of the delegate k-means. Does nothing while
     * the delegate has not been assigned yet.
     * 
     * @param seedSelection the seed selection method to use
     */
    @Override
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        if (kmeans != null)// needed when initing
            kmeans.setSeedSelection(seedSelection);
    }

    /**
     * @return the seed selection method of the delegate k-means
     */
    @Override
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return kmeans.getSeedSelection();
    }

    /**
     * Runs the delegate k-means for a fixed value of {@code k}, forwarding every
     * argument, including the data point weights.
     */
    @Override
    protected double cluster(DataSet dataSet, List<Double> accelCache, int k, List<Vec> means, int[] assignment, boolean exactTotal, boolean threadpool, boolean returnError, Vec dataPointWeights) {
        return kmeans.cluster(dataSet, accelCache, k, means, assignment, exactTotal, threadpool, returnError, dataPointWeights);
    }

    @Override
    public XMeans clone() {
        return new XMeans(this);
    }
}
